from haven import haven_utils as hu


# Dataset names (presumably from the PMLB benchmark collection -- TODO
# confirm) swept over by the "pmlb" experiment group. Each name is used as a
# "dataset" value in the experiment specs below.
pmlb_datasets = [
    'analcatdata_lawsuit',
    'australian',
    'backache',
    'biomed',
    'breast_cancer_wisconsin',
    'breast_cancer',
    'breast_w',
    'breast',
    'buggyCrx',
    'bupa',
    'chess',
    'churn',
    'clean1',
    'cleve',
    'colic',
    'corral',
    'credit_a',
    'credit_g',
    'crx',
    'diabetes',
    'dis',
    'flare',
    'GAMETES_Epistasis_2_Way_1000atts_0.4H_EDM_1_EDM_1_1',
    'GAMETES_Epistasis_2_Way_20atts_0.1H_EDM_1_1',
    'GAMETES_Epistasis_2_Way_20atts_0.4H_EDM_1_1',
    'GAMETES_Epistasis_3_Way_20atts_0.2H_EDM_1_1',
    'GAMETES_Heterogeneity_20atts_1600_Het_0.4_0.2_50_EDM_2_001',
    'GAMETES_Heterogeneity_20atts_1600_Het_0.4_0.2_75_EDM_2_001',
    'german',
    'glass2',
    'haberman',
    'heart_c',
    'heart_h',
    'heart_statlog',
    'hepatitis',
    'Hill_Valley_with_noise',
    'Hill_Valley_without_noise',
    'horse_colic',
    'house_votes_84',
    'hungarian',
    'hypothyroid',
    'ionosphere',
    'irish',
    'kr_vs_kp',
    'mofn_3_7_10',
    'monk1',
    'monk2',
    'monk3',
    'parity5+5',
    'pima',
    'prnn_crabs',
    'prnn_synth',
    'profb',
    'saheart',
    'sonar',
    'spambase',
    'spect',
    'spectf',
    'threeOf9',
    'tic_tac_toe',
    'tokyo1',
    'vote',
    'wdbc',
    'xd6',
]

# Values of the "runs" key shared by every PMLB group (presumably repetition
# indices / seeds -- TODO confirm with the training script).
run_list = [0, 1, 2, 3, 4]


def _pmlb_group(datasets):
    """Build the hyperparameter spec shared by all PMLB experiment groups.

    The PMLB groups differ only in which datasets they sweep. The common
    settings are defined once here so that they cannot diverge. Keys are
    emitted in a fixed order (dataset first) in case downstream code depends
    on key order.

    Args:
        datasets: list of dataset names to use as the "dataset" values.

    Returns:
        A new dict mapping each hyperparameter name to its list of values.
        The "runs" entry is the module-level ``run_list`` object itself.
    """
    return {
        "dataset": datasets,
        "model": ["linear"],
        "loss_func": ['logistic_loss'],
        "acc_func": ["logistic_accuracy"],
        "opt": ['nls', 'sls'],
        "batch_size": [128],
        "max_epoch": [100],
        "runs": run_list,
    }


# Experiment groups keyed by group name. Each value maps a hyperparameter name
# to a list of candidate values. At the bottom of this module every group is
# passed through hu.cartesian_exp_group (as the name suggests, presumably a
# cartesian product over these lists).
EXP_GROUPS = {
    "pmlb_small": _pmlb_group(["glass2"]),

    "pmlb": _pmlb_group(pmlb_datasets),

    # GAMETES subset, derived from the master list so the two cannot drift.
    "pmlb_gam": _pmlb_group(
        [name for name in pmlb_datasets if name.startswith('GAMETES_')]
    ),

    "mnist_nls": {
        "dataset": ["mnist"],
        "model": ["mlp"],
        "loss_func": ["softmax_loss"],
        "opt": [{"name": "sgd_armijo", "gamma": 2},
                {"name": "adam"},
                {"name": "nls"}],
        "acc_func": ["softmax_accuracy"],
        "batch_size": [128],
        "max_epoch": [100],
        "runs": [0, 1, 2],
    },

    "mnist_nls_small": {
        "dataset": ["mnist"],
        "model": ["small_nn"],
        "loss_func": ["softmax_loss"],
        "opt": [{"name": "nls", "gamma_decr": 0.7, "gamma_incr": 1.25},
                {"name": "nls", "gamma_decr": 0.9, "gamma_incr": 1.01},
                {"name": "sgd_armijo", "gamma": 2}],
        "acc_func": ["softmax_accuracy"],
        "batch_size": [128],
        "max_epoch": [50],
        "num_iterations_per_test": [500],
        "max_iteration": [2000],
        "runs": [0],
    },

    "cifar10_small": {
        "dataset": ["cifar10"],
        "model": ["resnet34"],
        "loss_func": ["softmax_loss"],
        "opt": [{"name": "sgd_armijo", "gamma": 2}, {"name": "nls"}],
        "acc_func": ["softmax_accuracy"],
        "batch_size": [128],
        "max_epoch": [50],
        "num_iterations_per_test": [500],
        "max_iteration": [25000],
        "runs": [0, 1, 2],
    },

    # Here gamma_incr is written as the reciprocal of gamma_decr.
    "mnist_nls_large": {
        "dataset": ["mnist"],
        "model": ["mlp"],
        "loss_func": ["softmax_loss"],
        "opt": [{"name": "nls", "gamma_decr": 0.7, "gamma_incr": 1.0/0.7},
                {"name": "nls", "gamma_decr": 0.9, "gamma_incr": 1.0/0.9},
                {"name": "sgd_armijo", "gamma": 2}],
        "acc_func": ["softmax_accuracy"],
        "batch_size": [128],
        "max_epoch": [200],
        "num_iterations_per_test": [500],
        "max_iteration": [80000],
        "runs": [0, 1, 2, 3],
    },

    "mnist_nls_conv": {
        "dataset": ["mnist"],
        "model": ["conv_nn"],
        "loss_func": ["softmax_loss"],
        "opt": [{"name": "nls", "gamma_decr": 0.7, "gamma_incr": 1.0/0.7},
                {"name": "nls", "gamma_decr": 0.9, "gamma_incr": 1.0/0.9},
                {"name": "sgd_armijo", "gamma": 2}],
        "acc_func": ["softmax_accuracy"],
        "batch_size": [128],
        "max_epoch": [200],
        "num_iterations_per_test": [500],
        "max_iteration": [80000],
        "runs": [0, 1, 2, 3],
    },
}


# Replace each group's hyperparameter spec with the experiment list that
# hu.cartesian_exp_group builds from it. Group names and group order are kept.
EXP_GROUPS = dict(
    (group_name, hu.cartesian_exp_group(group_spec))
    for group_name, group_spec in EXP_GROUPS.items()
)
